import random
from typing import Dict
import jpype
from jpype import JImplements, JOverride
from pipelines.prompta.learner.java_utils.dfa import TTTLearnerDFA
from pipelines.prompta.oracle import BaseOracle
from pipelines.prompta.utils import save_dfa, tuple2word, word2tuple
from prompta.utils.java_libs import AcexAnalyzers, DefaultQuery, Word
from .base_learner import BaseLearner
from .l_star import LStarLearner
from .ttt import TTTLearner


class LEARNANYWAYLearner(BaseLearner):
    """Drive a TTT core learner, restarting it from scratch when refinement fails.

    Every counterexample returned by the oracle is written into the oracle's
    ``membership_query_cache`` before the core learner is asked to refine its
    hypothesis. If refinement raises, the core learner is rebuilt via
    ``reset()`` and learning starts over. NOTE(review): the restart presumably
    makes progress because the oracle answers membership queries from that
    cache — confirm against ``BaseOracle``.
    """

    ID = "LEARNANYWAY"

    def __init__(self, oracle: BaseOracle, exp_dir: str):
        super().__init__(oracle, exp_dir)
        self.oracle = oracle
        self.exp_dir = exp_dir
        # reset() builds the core learner; building one here as well would
        # construct a TTTLearner only to discard it immediately.
        self.reset()

    def reset(self) -> None:
        """Replace the core learner with a fresh TTT learner for ``self.oracle``."""
        self.core_learner = TTTLearner(self.oracle, self.exp_dir).learner

    def learn(self):
        """Run the hypothesis / counterexample loop until the oracle accepts.

        Each counterexample is cached via ``update_cache`` and then used to
        refine the hypothesis. If refinement raises, the learner is reset and
        restarted. When the oracle reports no counterexample, the final
        hypothesis is saved with ``save_dfa`` to ``self.get_dfa_save_path()``.

        Returns:
            The final hypothesis object from the core learner's
            ``getHypothesisModel()``.
        """
        self.core_learner.startLearning()
        while True:
            hypothesis = self.core_learner.getHypothesisModel()
            ce = self.check_conjecture(hypothesis)
            if ce is None:
                break
            self.update_cache(ce)

            try:
                self.core_learner.refineHypothesis(ce)
            except Exception:
                # Refinement failed (presumably a Java exception surfaced
                # through JPype). Start over with a fresh learner; the
                # counterexample is already in the cache. Catching Exception
                # rather than using a bare ``except`` keeps KeyboardInterrupt /
                # SystemExit able to stop the run.
                self.reset()
                self.core_learner.startLearning()
        save_dfa(self.get_dfa_save_path(), hypothesis, self.oracle.jalphabet)

        return hypothesis

    def check_conjecture(self, hypothesis):
        """Ask the oracle for a counterexample to ``hypothesis``.

        Returns:
            The oracle's counterexample, or ``None`` when it finds none
            (``learn`` treats ``None`` as acceptance). The ``'DefaultQuery'``
            argument presumably asks for the counterexample as a LearnLib
            ``DefaultQuery`` — confirm against ``BaseOracle.check_conjecture``.
        """
        ce = self.oracle.check_conjecture(hypothesis, 'DefaultQuery')
        return ce

    def update_cache(self, ce) -> None:
        """Record counterexample ``ce`` in the oracle's membership-query cache.

        The key is the tuple form of ``ce.getInput()`` (via ``word2tuple``).
        The value is ``ce.getOutput()``.
        """
        ce_word = word2tuple(ce.getInput())

        # Add the original ce to the cache
        self.oracle.membership_query_cache[ce_word] = ce.getOutput()


